#!/usr/bin/env python3

import argparse
import collections
import concurrent.futures
import contextlib
import hashlib
import json
import logging
import os
import platform
import shutil

# permit ctrl+c to work correctly
import signal as _signal  # noqa
import subprocess
import sys
import tempfile
import time
import typing
import urllib.parse
import urllib.request
from datetime import datetime
from pathlib import Path


#
# ------------------------------------------------------------------------------
# LOGGING SETUP
# ------------------------------------------------------------------------------
#
class CustomFormatter(logging.Formatter):
    """
    A custom logging formatter that preserves the original script's
    date/time format, color codes, and bracketed icons for each level.
    """

    # When running on a TTY (interactive), use color codes. Otherwise no color.
    if sys.stderr.isatty():
        green = "\x1b[32m"
        yellow = "\x1b[33m"
        blue = "\x1b[34m"
        red = "\x1b[31m"
        bold_red = "\x1b[31;1m"
        reset = "\x1b[0m"
    else:
        green = ""
        yellow = ""
        blue = ""
        red = ""
        bold_red = ""
        reset = ""

    def __init__(self, one_line_logs: bool):
        super().__init__()
        self.one_line_logs = one_line_logs

    def formatTime(self, record: logging.LogRecord, datefmt: str | None = None) -> str:
        ts = record.created
        dt = datetime.fromtimestamp(ts)
        return dt.strftime("%Y/%m/%d | %H:%M:%S | ") + str(int(ts))

    def format(self, record: logging.LogRecord) -> str:
        # Assign an icon + color for each level
        if record.levelno == logging.DEBUG:
            icon = "[🐞]"
            color = self.blue
        elif record.levelno == logging.INFO:
            icon = "[ℹ️]"
            color = self.green
        elif record.levelno == logging.WARNING:
            icon = "[⚠️ Warning]"
            color = self.yellow
        elif record.levelno == logging.ERROR:
            icon = "[❌]"
            color = self.red
        elif record.levelno == logging.CRITICAL:
            icon = "[💥]"
            color = self.bold_red
        else:
            icon = f"[{record.levelname}]"
            color = self.reset

        #  YYYY/MM/DD | HH:MM:SS | EPOCH  [ICON]  message
        log_fmt = "%(color)s%(asctime)s %(icon)s %(message)s%(reset)s"

        # Use our custom formatter.
        formatter = logging.Formatter(fmt=log_fmt, datefmt=None)

        record.color = color
        record.icon = icon
        record.reset = self.reset

        formatter.formatTime = self.formatTime  # type: ignore
        output = formatter.format(record)
        return output


def conventional_logging(one_line_logs: bool, verbose: bool) -> None:
    """
    Replace the root logger's handlers with a single stream handler.

    The handler uses CustomFormatter. Root and handler level are DEBUG when
    ``verbose`` is set and INFO otherwise.
    """
    level = logging.DEBUG if verbose else logging.INFO

    root = logging.getLogger()
    root.handlers.clear()
    root.setLevel(level)

    # Mute overly chatty third-party loggers if not verbose
    if not verbose:
        for noisy_name in ("httpcore", "urllib3", "httpx"):
            logging.getLogger(noisy_name).setLevel(logging.WARNING)

    handler = logging.StreamHandler()
    handler.setLevel(level)
    handler.setFormatter(CustomFormatter(one_line_logs=one_line_logs))
    root.addHandler(handler)


class Dirs:
    """Plain holder for the working directories used during one verification run."""

    def __init__(self, tmp_dir: Path, out_dir: Path, cdn_out: Path, dev_out: Path, proposal_out: Path):
        # Root of the temporary work area and the output root beneath it.
        self.tmp_dir, self.out_dir = tmp_dir, out_dir
        # Destinations for CDN downloads, locally built images and proposal packages.
        self.cdn_out, self.dev_out, self.proposal_out = cdn_out, dev_out, proposal_out


# Module-level logger shared by the functions and classes in this file.
logger = logging.getLogger(__name__)


class VerificationError(RuntimeError):
    """Raised when an artifact, checksum listing, or proposal fails verification."""


def verify_sha256_against_sums(binary_path: Path, sums_file: Path, git_rev: str) -> str:
    """
    Verify a file's SHA-256 against an entry in a SHA256SUMS file.

    Args:
        binary_path: File whose digest is computed; its base name is the key
            looked up in ``sums_file``.
        sums_file: Text file with ``<hash> <name>`` lines (a leading ``*`` on
            the name is ignored).
        git_rev: Git revision, used only in the mismatch error message.

    Returns:
        The hash listed in ``sums_file`` for the file.

    Raises:
        VerificationError: If the file is not listed or the digests differ.
    """
    logger.debug(f"Verifying {binary_path} against {sums_file}")
    lines = sums_file.read_text(encoding="utf-8").splitlines()
    binary_filename = binary_path.name

    listed_hash = None
    for line in lines:
        parts = line.split()
        if len(parts) >= 2 and parts[1].lstrip("*") == binary_filename:
            listed_hash = parts[0]
            break
    if listed_hash is None:
        # splitlines() strips the line breaks, so re-join with newlines to keep
        # the dumped file readable.
        contents = "\n".join(lines)
        logger.info(f"Contents of {sums_file}:\n{contents}")
        raise VerificationError(f"Couldn't find {binary_filename} in {sums_file}")

    local_hash = compute_sha256(binary_path)
    if local_hash != listed_hash:
        raise VerificationError(
            f"The hash for {binary_path} doesn't match the CDN sha256 sum for the git revision: {git_rev}\n"
            f"Local: {local_hash}\n"
            f"Published: {listed_hash}"
        )
    return listed_hash


def compare_hashes(local_hash_value: str, remote_hash_value: str, os_type: str) -> bool:
    """
    Report whether the locally built hash equals the remote one for ``os_type``.

    Logs the outcome (error on mismatch, info on success) and returns True
    only when both hashes are equal.
    """
    if local_hash_value == remote_hash_value:
        logger.info(f"Verification successful for {os_type}!")
        logger.info(
            f"The sha256 sum for {os_type} from the artifact built locally and "
            f"the one fetched from remote match:\n"
            f"\tLocal = {local_hash_value}\n"
            f"\tRemote = {remote_hash_value}\n"
        )
        return True

    logger.error(
        f"Error! The sha256 sum from the remote does not match the one we just built for {os_type}.\n"
        f"Local:   {local_hash_value}\n"
        f"Remote:  {remote_hash_value}"
    )
    return False


def compute_sha256(file_path: Path) -> str:
    """Return the hex SHA-256 digest of ``file_path``, read in 64 KiB chunks."""
    block_size = 64 * 1024
    digest = hashlib.sha256()
    with file_path.open("rb") as stream:
        while block := stream.read(block_size):
            digest.update(block)
    return digest.hexdigest()


# Suppression markers for the progress bar: callbacks returned by progress()
# draw only while this deque is empty. build_locally() pushes a marker for the
# duration of the clone/build and pops it afterwards.
progress_hidden: collections.deque[None] = collections.deque()
# Checked by fetch_url_to_file() after every chunk to abandon an in-flight
# download. NOTE(review): nothing in this part of the file sets it, presumably
# a signal handler elsewhere does; confirm.
interrupted = False


def progress(
    prefix: str = "",
    out: typing.TextIO = sys.stderr,
) -> typing.Callable[[float], None]:
    """
    Build a progress-bar callback.

    The returned callable takes a completion fraction. It draws only when
    stderr is a TTY and ``progress_hidden`` is empty; a fraction of 1.0 or
    more blanks the line instead of drawing a bar.
    """

    def columns() -> int:
        # Fall back to a conventional width when the terminal can't be queried.
        try:
            return os.get_terminal_size().columns
        except Exception:
            return 79

    # The label is "<prefix> ", or nothing at all when the prefix is blank.
    label = f"{prefix} "
    if not label.strip():
        label = ""

    last_filled: int | None = None

    def draw(fraction: float) -> None:
        nonlocal last_filled
        inner = columns() - len(label) - 2  # 2 accounts for the brackets
        filled = int(round(fraction * inner))
        if filled == last_filled:
            # Nothing visibly changed since the previous call; skip the redraw.
            return
        bar = "█" * filled + "." * (inner - filled)
        print(f"{label}[{bar}]", end="\r", file=out, flush=True)
        last_filled = filled

    def callme(progress: float) -> None:
        if not sys.stderr.isatty() or progress_hidden:
            return
        if progress < 1.0:
            draw(progress)
        else:
            print("\r" + " " * columns(), end="\r", flush=True, file=out)

    return callme


def fetch_url_to_file(url: str, dest_path: Path) -> Path:
    """
    Download ``url`` to ``dest_path`` via a temporary ``.part`` file.

    The data is streamed to ``<dest_path>.part`` and renamed into place only
    after the transfer completes, so ``dest_path`` never holds a partial file.

    Args:
        url: Source URL.
        dest_path: Final location; parent directories are created as needed.

    Returns:
        ``dest_path`` once the file is fully downloaded.

    Raises:
        RuntimeError: On any download failure, on a size mismatch against the
            advertised Content-Length, or when the module-level ``interrupted``
            flag is set mid-transfer. The partial file is removed first.
    """
    logger.debug(f"Downloading {url} to {dest_path}")
    dest_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_file = dest_path.with_suffix(dest_path.suffix + ".part")

    try:
        req = urllib.request.Request(url, headers={"User-Agent": "ReproducibilityVerifier/1.0"})
        with urllib.request.urlopen(req) as response:
            # A missing header yields None (no KeyError), e.g. for chunked
            # responses; treat missing or malformed values as "size unknown".
            raw_length = response.headers.get("Content-Length")
            length: int | None
            try:
                length = int(raw_length) if raw_length is not None else None
            except ValueError:
                length = None
            sofar = 0
            progressbar = progress(str(dest_path))
            with tmp_file.open("wb") as out_file:
                while chunk := response.read(64 * 1024):
                    sofar += len(chunk)
                    out_file.write(chunk)
                    if length:  # also guards against division by zero
                        progressbar(sofar / length)
                    if interrupted:
                        # Previously this path returned the ``Path`` class
                        # itself; fail explicitly so callers never mistake an
                        # aborted transfer for a finished one.
                        raise RuntimeError("Download interrupted.")
        if length is not None and length != sofar:
            raise RuntimeError(
                f"Error downloading {url} -> {dest_path}. File is supposed to be {length} bytes, is {sofar} bytes."
            )
        tmp_file.rename(dest_path)
        logger.info("Downloaded %s to %s", url, dest_path)
        return dest_path
    except Exception as e:
        if tmp_file.is_file():
            tmp_file.unlink()
        raise RuntimeError(f"Could not download {url} -> {dest_path}. Error: {e}") from e


# ------------------------------------------------------------------------------
# ARGUMENT PARSING
# ------------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
    """
    Parse the command line.

    When none of the OS selection flags is given, all four (guestos, hostos,
    setupos, recovery) are switched on in the returned namespace.
    """
    description = (
        "This script builds and verifies reproducibility by comparing "
        "the images built by the CI vs. those built locally.\n\n"
        "Default behavior:\n"
        " - uses the current directory's HEAD commit\n"
        " - checks all OS images (GuestOS, HostOS, SetupOS)\n"
        " - compares hash sums against the artifacts from download.dfinity.systems and download.dfinity.network.\n\n"
    )
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.RawTextHelpFormatter)

    # One boolean selector per OS image kind.
    os_selectors = (
        ("guestos", "GuestOS"),
        ("hostos", "HostOS"),
        ("setupos", "SetupOS"),
        ("recovery", "recovery-guestos"),
    )
    for dest, label in os_selectors:
        parser.add_argument(f"--{dest}", action="store_true", help=f"Verify only {label} images.")

    parser.add_argument(
        "-p",
        "--proposal-id",
        type=str,
        default="",
        help="Proposal ID to check (for an Elect Replica or HostOS proposal).",
    )
    parser.add_argument(
        "-c",
        "--commit",
        type=str,
        default="",
        help="Git revision/commit to use from the IC repository.",
    )
    parser.add_argument(
        "--cache-dir",
        type=Path,
        default=None,
        help="Cache directory to use (defaults to ~/.cache/repro-check).",
    )
    parser.add_argument("--clean", action="store_true", help="Clean up the download cache before running.")
    parser.add_argument("--debug", "--verbose", action="store_true", help="Enable debug mode output.")
    parser.add_argument(
        "--download-source",
        choices=["systems", "network", "both"],
        default="systems",
        help="Which source to download from: .systems, .network, or both (default: systems).",
    )

    args = parser.parse_args()

    # No explicit selection means "verify everything".
    if not any(getattr(args, dest) for dest, _ in os_selectors):
        for dest, _ in os_selectors:
            setattr(args, dest, True)

    return args


def get_download_sources(
    mode: typing.Literal["systems"] | typing.Literal["network"] | typing.Literal["both"],
) -> list[str]:
    """Returns the list of base CDN domains from which the images should be downloaded."""
    if mode == "systems":
        return ["download.dfinity.systems"]
    elif mode == "network":
        return ["download.dfinity.network"]
    elif mode == "both":
        return ["download.dfinity.systems", "download.dfinity.network"]
    else:
        raise ValueError(f"Invalid download source mode: {mode}")


# A pending download: a Future that resolves to the local path of the fetched file.
Download = concurrent.futures.Future[Path]


# ------------------------------------------------------------------------------
# REPRODUCIBILITY VERIFIER
# ------------------------------------------------------------------------------
class ReproducibilityVerifier:
    def __init__(
        self,
        verify_guestos: bool,
        verify_hostos: bool,
        verify_setupos: bool,
        verify_recovery: bool,
        proposal_id: str,
        git_commit: str,
        download_source_mode: str,
        base_cache_dir: Path | None,
        clean_base_cache_dir: bool,
        keep_temp: bool,
    ):
        """
        Configure a verification run.

        Args:
            verify_guestos, verify_hostos, verify_setupos, verify_recovery:
                Which image kinds to verify.
            proposal_id: Proposal to check, or "" for none.
            git_commit: Explicit git revision, or "" for none.
            download_source_mode: Passed to get_download_sources().
            base_cache_dir: Cache root; None selects ~/.cache/repro-check.
            clean_base_cache_dir: Wipe the default cache root before use.
            keep_temp: Stored for later use by the run.

        Raises:
            RuntimeError: If cleaning is requested for an explicitly supplied
                cache directory.
            ValueError: From get_download_sources() on an unknown mode.
        """
        self.verify_guestos = verify_guestos
        self.verify_hostos = verify_hostos
        self.verify_setupos = verify_setupos
        self.verify_recovery = verify_recovery
        self.proposal_id = proposal_id
        self.git_commit = git_commit

        self.cdn_domains = get_download_sources(download_source_mode)
        logger.debug(f"CDNs selected: {self.cdn_domains}")

        self.git_hash = ""
        self.proposal_package_urls: list[str] = []
        self.proposal_package_sha256_hex = ""

        self.keep_temp = keep_temp

        # Cache
        if base_cache_dir is not None:
            # Never delete a directory the user pointed us at.
            if clean_base_cache_dir:
                raise RuntimeError(f"Refusing to clean manually specified base cache directory {base_cache_dir}")
            self.base_cache_dir = base_cache_dir
        else:
            self.base_cache_dir = Path(os.path.expanduser("~/.cache/repro-check"))
            # base_cache_dir is None in this branch, so the resolved default
            # must be the one inspected and removed.
            if clean_base_cache_dir and self.base_cache_dir.is_dir():
                shutil.rmtree(self.base_cache_dir)
        self.cache_for_this_hash: Path = Path()

        self.download_executor = concurrent.futures.ThreadPoolExecutor(max_workers=9)

    # --------------------------------------------------------------------------
    # ENVIRONMENT CHECKS
    # --------------------------------------------------------------------------
    def check_architecture(self) -> None:
        """Raise RuntimeError unless ``platform.machine()`` reports x86_64."""
        if platform.machine() != "x86_64":
            raise RuntimeError("Please run this script on x86_64 architecture.")
        logger.info("x86_64 architecture detected.")

    def check_os_version(self) -> None:
        """
        Log whether the host looks like Ubuntu 22.04 or newer.

        This check is advisory only: it logs info/warning messages and never
        raises.
        """
        # The stdlib helper reads /etc/os-release first and only falls back to
        # /usr/lib/os-release when the former is missing. The previous
        # hand-rolled parser merged both files and let the fallback win.
        try:
            os_info = platform.freedesktop_os_release()
        except OSError:
            # Neither file could be read; behave as if nothing is known.
            os_info = {}

        os_name = os_info.get("NAME", "")
        if os_name == "Ubuntu":
            logger.info("Ubuntu OS detected.")
        else:
            logger.warning("Please run this script on Ubuntu OS.")

        raw_version = os_info.get("VERSION_ID", "0")
        try:
            version_id = float(raw_version)
        except ValueError:
            # Non-numeric VERSION_ID: skip the version check, but leave a trace
            # instead of swallowing it silently.
            logger.debug(f"Could not interpret VERSION_ID {raw_version!r} as a number; skipping version check.")
            return

        if version_id >= 22.04:
            logger.info("Ubuntu version ≥ 22.04 detected.")
        else:
            logger.warning("Please run this script on Ubuntu version 22.04 or higher.")

    def check_memory_at_least_gb(self, required_gb: int = 16) -> None:
        """
        Log whether MemTotal in /proc/meminfo reaches ``required_gb`` GiB.

        Advisory only: any failure to read or parse the file is logged as a
        warning rather than raised.
        """
        try:
            with Path("/proc/meminfo").open("r") as meminfo:
                total_line = next((ln for ln in meminfo if ln.startswith("MemTotal:")), None)

            if total_line is None:
                logger.warning("Could not detect memory from /proc/meminfo.")
                return

            fields = total_line.split()
            if len(fields) < 2:
                return

            # The value is in kB; whole GiB via two floor divisions.
            total_gb = int(fields[1]) // 1024 // 1024
            if total_gb >= required_gb:
                logger.info(f"{required_gb} GiB or more RAM detected.")
            else:
                logger.warning(f"You need at least {required_gb} GiB of RAM on this machine.")
        except Exception:
            logger.warning("Memory check failed. Could not parse /proc/meminfo.")

    def check_disk_at_least_gb(self, required_gb: int = 100) -> None:
        """
        Log whether the filesystem holding the current directory has at least
        ``required_gb`` GiB available. Advisory only; failures are logged.
        """
        try:
            fs = os.statvfs(".")
            available_gb = (fs.f_bavail * fs.f_frsize) / (1024 * 1024 * 1024)
            if available_gb >= required_gb:
                logger.info(f"{required_gb} GiB+ of free disk space detected.")
            else:
                logger.warning(f"You need at least {required_gb} GiB of free disk space on this machine.")
        except Exception:
            logger.warning("Disk check failed. Could not run statvfs on '.'.")

    def check_and_install_dependencies(self) -> None:
        """
        Ensure ``git`` and ``podman`` are on PATH, installing missing ones
        with ``sudo apt-get install -y``.

        Raises:
            RuntimeError: If an installation command exits non-zero.
        """
        logger.info("Checking and installing needed dependencies.")
        for tool in ("git", "podman"):
            if shutil.which(tool) is not None:
                logger.info(f"{tool} is already installed.")
                continue

            logger.info(f"Installing missing package: {tool}")
            install_cmd = ["sudo", "apt-get", "install", "-y", tool]
            try:
                subprocess.run(install_cmd, check=True)
            except subprocess.CalledProcessError:
                raise RuntimeError(f"Failed to install {tool}. Exiting.")

    def check_git_repo(self) -> None:
        """
        Verify the current working directory is inside a git work tree.

        Raises:
            RuntimeError: If ``git rev-parse --is-inside-work-tree`` fails or
                does not print ``true``.
        """
        logger.debug("Checking we are inside a Git repository.")
        cmd = ["git", "rev-parse", "--is-inside-work-tree"]
        try:
            inside = subprocess.check_output(cmd, stderr=subprocess.DEVNULL).decode().strip()
        except subprocess.CalledProcessError as e:
            # git exits non-zero outside a repository. The original built this
            # exception without raising it, so the failure went unnoticed.
            raise RuntimeError("Please run this script inside of a git repository.") from e
        if inside != "true":
            raise RuntimeError("Please run this script inside of a git repository.")
        logger.debug("Inside git repository.")

    def check_ic_repo(self) -> None:
        """
        Verify that ``remote.origin.url`` of the current repository contains "/ic".

        Raises:
            RuntimeError: If the remote URL cannot be read or lacks "/ic".
        """
        logger.debug("Checking the repository is the IC repository.")
        not_ic_repo = "When not specifying any option, please run this script inside an IC git repository."
        query = ["git", "config", "--get", "remote.origin.url"]
        try:
            origin_url = subprocess.check_output(query, stderr=subprocess.DEVNULL).decode().strip()
        except subprocess.CalledProcessError:
            raise RuntimeError(not_ic_repo)
        if "/ic" not in origin_url:
            raise RuntimeError(not_ic_repo)
        logger.debug("Inside IC repository.")

    def check_environment(self) -> None:
        """Run the host pre-flight checks: architecture, OS, 16 GiB RAM, 100 GiB disk, tools."""
        self.check_architecture()
        self.check_os_version()
        self.check_memory_at_least_gb(required_gb=16)
        self.check_disk_at_least_gb(required_gb=100)
        self.check_and_install_dependencies()

    # --------------------------------------------------------------------------
    # Cache management + download scheduling
    # --------------------------------------------------------------------------
    def init_cache(self) -> None:
        """
        Create the cache directory for ``self.git_hash`` and prune older ones.

        Besides the directory for the current hash, only the two most recently
        modified subdirectories of the base cache directory are kept.
        """
        self.base_cache_dir.mkdir(parents=True, exist_ok=True)
        self.cache_for_this_hash = self.base_cache_dir / self.git_hash

        keep = 2
        # The current hash's directory is excluded from pruning: otherwise a
        # cache that is older than two others would be deleted right before
        # being reused, forcing a full re-download.
        # NOTE(review): every other subdirectory, including the `repo`
        # reference clone and `tmp-*` work dirs created elsewhere in this
        # class, is still subject to pruning; confirm that is intended.
        candidates = [
            sub for sub in self.base_cache_dir.iterdir() if sub.is_dir() and sub != self.cache_for_this_hash
        ]
        candidates.sort(key=lambda p: p.stat().st_mtime, reverse=True)
        for older in candidates[keep:]:
            logger.debug(f"Removing old cache: {older}")
            shutil.rmtree(older, ignore_errors=True)

        self.cache_for_this_hash.mkdir(parents=True, exist_ok=True)
        # Refresh the mtime so this cache counts as the most recently used.
        self.cache_for_this_hash.touch()
        logger.info("Using cache directory: %s", self.cache_for_this_hash)

    def cached_download(self, url: str, target_file: Path, os_type: str) -> Path:
        """
        Fetch ``url`` through the per-hash cache and hard-link it to ``target_file``.

        The cache entry lives under ``<cache>/<url host>/<os_type>/`` and is
        reused when it exists and is non-empty. Returns ``target_file``.
        """
        host = urllib.parse.urlparse(url).netloc.split(":")[0]
        cache_dir = self.cache_for_this_hash / host / os_type
        cache_dir.mkdir(parents=True, exist_ok=True)
        cached_copy = cache_dir / target_file.name

        already_cached = cached_copy.exists() and cached_copy.stat().st_size > 0
        if already_cached:
            logger.debug(f"File already cached, skipping download: {url}")
        else:
            logger.debug(f"Cache miss, downloading: {url}")
            fetch_url_to_file(url, cached_copy)

        # Replace any stale target with a fresh hard link to the cached file.
        target_file.parent.mkdir(parents=True, exist_ok=True)
        if target_file.exists():
            target_file.unlink()
        os.link(cached_copy, target_file)
        return target_file

    def start_download(self, url: str, target_file: Path, os_type: str, append_to_futures: bool = True) -> Download:
        """
        Schedule ``cached_download`` on the download executor and return its Future.

        ``append_to_futures`` is accepted but not used by this method.
        """
        pending = self.download_executor.submit(self.cached_download, url, target_file, os_type)
        return pending

    # --------------------------------------------------------------------------
    # Core logic restructured to start downloads early
    # --------------------------------------------------------------------------
    def process_proposal(self) -> tuple[str | None, str | None]:
        """If a proposal ID is provided, fetch the proposal data and set internal state."""
        if not self.proposal_id:
            return (None, None)
        proposal_url = f"https://ic-api.internetcomputer.org/api/v3/proposals/{self.proposal_id}"
        logger.debug(f"Fetching proposal {proposal_url}")
        try:
            req = urllib.request.Request(proposal_url, headers={"User-Agent": "ReproducibilityVerifier/1.0"})
            with urllib.request.urlopen(req) as resp:
                if not (200 <= resp.status < 300):
                    raise RuntimeError(f"Could not fetch proposal {self.proposal_id}, HTTP code {resp.status}")
                data = resp.read()
        except Exception as e:
            err = f"Could not fetch {self.proposal_id}. Error: {e}"
            raise RuntimeError(err)

        proposal_data = json.loads(data)
        self.proposal_package_urls = proposal_data["payload"]["release_package_urls"]
        self.proposal_package_sha256_hex = proposal_data["payload"]["release_package_sha256_hex"]

        prop_str = json.dumps(proposal_data)
        if "replica_version_to_elect" in prop_str:
            self.git_hash = proposal_data["payload"]["replica_version_to_elect"]
            return (self.proposal_package_sha256_hex, None)
        elif "hostos_version_to_elect" in prop_str:
            self.git_hash = proposal_data["payload"]["hostos_version_to_elect"]
            return (None, self.proposal_package_sha256_hex)
        else:
            err = f"Proposal #{self.proposal_id} is missing replica_version_to_elect or hostos_version_to_elect"
            raise VerificationError(err)

    def decide_git_hash(self) -> None:
        """Determines which git hash to build and verify."""
        if self.proposal_id:
            return
        if self.git_commit:
            self.git_hash = self.git_commit
        else:
            self.check_git_repo()
            self.git_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip()

    def start_cdn_downloads_for_os(self, os_type: str, storage: Dirs) -> list[Download]:
        """
        Schedule the artifact and SHA256SUMS downloads for ``os_type`` from every CDN.

        Returns the futures in scheduling order: artifact, then sums file, per
        CDN domain.
        """
        # os_type -> (URL path component, artifact base name, URL subdirectory).
        special_layouts = {
            "setup-os": ("setup-os", "disk-img", "disk-img"),
            "recovery-guestos": ("guest-os", "update-img", "update-img-recovery"),
        }
        path_component, artifact, subdir_name = special_layouts.get(
            os_type, (os_type, "update-img", "update-img")
        )

        tar_name = f"{artifact}.tar.zst"
        downloads: list[Download] = []
        for cdn_domain in self.cdn_domains:
            remote_dir = f"https://{cdn_domain}/ic/{self.git_hash}/{path_component}/{subdir_name}"
            local_dir = storage.cdn_out / cdn_domain / os_type
            for file_name in (tar_name, "SHA256SUMS"):
                downloads.append(self.start_download(f"{remote_dir}/{file_name}", local_dir / file_name, os_type))
        return downloads

    def start_proposal_download_if_needed(self, storage: Dirs) -> list[Download]:
        """If a proposal is set, download the package from the proposal-specified URLs."""
        downloads: list[Download] = []
        if not self.proposal_id:
            return downloads
        for url in self.proposal_package_urls:
            parsed = urllib.parse.urlparse(url)
            domain = parsed.netloc.split(":")[0]
            proposal_target = storage.proposal_out / domain / "update-img.tar.zst"
            downloads.append(self.start_download(url, proposal_target, "proposal", append_to_futures=False))
        return downloads

    @contextlib.contextmanager
    def storage(self, keep_temp: bool) -> typing.Generator[Dirs, None, None]:
        """
        Provide the temporary and output directories for one run.

        A ``tmp-<epoch>`` directory is created under the base cache directory
        and yielded wrapped in a Dirs. Unless ``keep_temp`` is set, it is
        removed when the ``with`` block exits, including when it exits with an
        exception.
        """
        if keep_temp:
            logger.debug("DEBUG mode => not automatically removing the tempdir.")

        # Create temporary directory under cache location
        tmpdir = self.base_cache_dir / f"tmp-{int(time.time())}"
        tmpdir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Using temporary directory: {tmpdir}")

        out_dir = tmpdir / "disk-images" / self.git_hash

        try:
            yield Dirs(
                tmpdir,
                out_dir,
                out_dir / "cdn-img",
                out_dir / "dev-img",
                out_dir / "proposal-img",
            )
        finally:
            # In ``finally`` so a failing run does not leave the temporary
            # tree behind.
            if not keep_temp and tmpdir.exists():
                logger.debug("Cleaning up temporary directory.")
                shutil.rmtree(tmpdir)

    # --------------------------------------------------------------------------
    # Verification steps
    # --------------------------------------------------------------------------
    def verify_proposal_artifacts(self, downloads: list[Download] = []) -> None:
        """Verifies proposal artifact SHA-256 if a proposal is specified."""
        for completed_download in concurrent.futures.as_completed(downloads):
            proposal_target = completed_download.result()
            actual_hash = compute_sha256(proposal_target)
            if actual_hash != self.proposal_package_sha256_hex:
                raise VerificationError(
                    f"The proposal's artifact hash for {proposal_target} does not match!\n"
                    f"Expected: {self.proposal_package_sha256_hex}\n"
                    f"Actual:   {actual_hash}"
                )
        logger.info("The proposal's artifact and hash match.")

    def compare_cdn_hash(self, os_type: str, storage: Dirs) -> str:
        """
        Verify the downloaded artifact for ``os_type`` on every CDN and return its hash.

        Raises:
            VerificationError: If a per-CDN check fails or the CDNs disagree.
        """
        artifact = "disk-img" if os_type == "setup-os" else "update-img"

        final_hashes = [
            verify_sha256_against_sums(
                storage.cdn_out / cdn_domain / os_type / f"{artifact}.tar.zst",
                storage.cdn_out / cdn_domain / os_type / "SHA256SUMS",
                self.git_hash,
            )
            for cdn_domain in self.cdn_domains
        ]

        # All sources must agree on one hash.
        if len(set(final_hashes)) != 1:
            raise VerificationError(f"The sources for {os_type} do not all match! {final_hashes}")
        return final_hashes[0]

    def compare_proposal_vs_cdn(
        self,
        storage: Dirs,
        guest_os_hash: str | None,
        guest_os_downloads: list[Download],
        host_os_hash: str | None,
        host_os_downloads: list[Download],
    ) -> None:
        """Compares the proposal’s artifact hash against the CDN-stored hash if a proposal is specified."""
        if not self.proposal_id:
            return
        if guest_os_hash:
            [f.result() for f in concurrent.futures.as_completed(guest_os_downloads)]
            cdn_hash = self.compare_cdn_hash("guest-os", storage)
            if cdn_hash != guest_os_hash:
                raise VerificationError(
                    "The sha256 sum from the proposal does not match the one from the CDN storage for GuestOS.\n"
                    f"Proposal sum: {guest_os_hash}\n"
                    f"CDN sum:      {cdn_hash}"
                )
            else:
                logger.info("The GuestOS sha256sum from the proposal and remote match.")
        if host_os_hash:
            [f.result() for f in concurrent.futures.as_completed(host_os_downloads)]
            cdn_hash = self.compare_cdn_hash("host-os", storage)
            if cdn_hash != host_os_hash:
                raise VerificationError(
                    "The sha256 sum from the proposal does not match the one from the CDN storage for HostOS.\n"
                    f"Proposal sum: {host_os_hash}\n"
                    f"CDN sum:      {cdn_hash}"
                )
            else:
                logger.info("The HostOS sha256sum from the proposal and remote match.")

    def clone_and_checkout_repo(self, ic_clone_path: Path) -> None:
        """
        Clone the IC repository into ``ic_clone_path`` and check out ``self.git_hash``.

        With the ``CI`` environment variable set, the clone source is the
        current directory; otherwise it is GitHub, using ``<base cache>/repo``
        as a reference clone when usable. The process working directory is
        changed to the clone, and the clone is copied back over the reference
        cache at the end.

        Raises:
            RuntimeError: If the revision cannot be fetched, or ``-c`` was
                given and the commit is not valid in the clone.
            subprocess.CalledProcessError: If the clone or checkout fails.
        """
        reference_repo = self.base_cache_dir / "repo"

        if os.getenv("CI") is None:
            logger.info("Cloning IC repository from GitHub.")
            # A reference clone that fails fsck is discarded before reuse.
            fsck_cmd = ["git", "-C", str(reference_repo), "fsck"]
            try:
                subprocess.run(fsck_cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            except subprocess.CalledProcessError:
                logger.debug(f"Git fsck failed in cache {reference_repo}, removing IC repo git cache.")
                shutil.rmtree(reference_repo, ignore_errors=True)
            clone_cmd = [
                "git",
                "clone",
                "--reference-if-able",
                str(reference_repo),
                "--dissociate",
                "https://github.com/dfinity/ic",
                str(ic_clone_path),
            ]
            subprocess.run(clone_cmd, check=True)
        else:
            logger.info(f"Copying IC repository from {Path.cwd()} to temporary directory.")
            subprocess.run(["git", "clone", str(Path.cwd()), str(ic_clone_path)], check=True)

        # The git commands below run relative to the new clone.
        os.chdir(ic_clone_path)

        logger.info(f"Fetching {self.git_hash} from origin.")
        fetch_cmd = ["git", "fetch", "--quiet", "origin", self.git_hash]
        try:
            subprocess.run(fetch_cmd, check=True)
        except subprocess.CalledProcessError:
            raise RuntimeError(f"Could not fetch commit {self.git_hash} from the remote repository.")

        if self.git_commit:
            self.check_git_repo()
            self.check_ic_repo()
            exists_cmd = ["git", "cat-file", "-e", f"{self.git_commit}^{{commit}}"]
            try:
                subprocess.run(exists_cmd, check=True)
            except subprocess.CalledProcessError:
                raise RuntimeError(
                    "When specifying -c, please provide a valid git hash on a branch of the IC repository."
                )

        logger.info(f"Checking out {self.git_hash}.")
        subprocess.run(["git", "checkout", "--quiet", self.git_hash], check=True)

        # Refresh the reference cache from the clone that was just made.
        shutil.rmtree(reference_repo, ignore_errors=True)
        shutil.copytree(ic_clone_path, reference_repo)

    def build_locally(
        self,
        storage: Dirs,
    ) -> None:
        """
        Clone the repository, run the IC-OS build, and collect the selected images.

        Progress bars are suppressed while the clone and build run. Each
        selected artifact is then moved from ``<clone>/artifacts/icos`` to the
        same relative path under ``storage.dev_out``.
        """
        # A marker in progress_hidden keeps download progress bars quiet.
        progress_hidden.append(None)
        try:
            ic_clone_path = storage.tmp_dir / "ic"
            self.clone_and_checkout_repo(ic_clone_path)

            logger.info("Building IC-OS (./ci/container/build-ic.sh --icos).")
            subprocess.run(["./ci/container/build-ic.sh", "--icos"], check=True)
            logger.info("IC-OS build complete.")
        finally:
            progress_hidden.popleft()

        built_root = ic_clone_path / "artifacts" / "icos"
        selections = (
            (self.verify_guestos, "guestos/update/update-img.tar.zst"),
            (self.verify_hostos, "hostos/update/update-img.tar.zst"),
            (self.verify_setupos, "setupos/disk-img.tar.zst"),
            (self.verify_recovery, "guestos/update-img-recovery/update-img.tar.zst"),
        )
        for wanted, relative in selections:
            if not wanted:
                continue
            destination = storage.dev_out / relative
            destination.parent.mkdir(parents=True, exist_ok=True)
            shutil.move(str(built_root / relative), str(destination))

    def compare_with_local_build(
        self,
        storage: Dirs,
        compile_future: concurrent.futures.Future[None],
        guest_os_downloads: list[Download],
        host_os_downloads: list[Download],
        setup_os_downloads: list[Download],
        recovery_downloads: list[Download],
    ) -> None:
        """Check each locally built image against its CDN counterpart.

        Blocks until the local build has finished, then, for every OS whose
        download list is non-empty, waits for those downloads, hashes the
        locally built artifact under ``storage.dev_out`` and hands that hash,
        together with the one returned by ``self.compare_cdn_hash``, to
        ``compare_hashes``.

        Args:
            storage: Directory layout; ``dev_out`` holds the local artifacts.
            compile_future: Future of the local build; any exception it raised
                is re-raised here by ``result()``.
            guest_os_downloads: Pending GuestOS CDN downloads (empty to skip).
            host_os_downloads: Pending HostOS CDN downloads (empty to skip).
            setup_os_downloads: Pending SetupOS CDN downloads (empty to skip).
            recovery_downloads: Pending recovery image CDN downloads (empty to skip).
        """
        compile_future.result()
        logger.info("Verifying locally built artifacts against remote CDN artifacts.")

        # (downloads, artifact path under storage.dev_out, CDN OS key, display label)
        checks = (
            (guest_os_downloads, "guestos/update/update-img.tar.zst", "guest-os", "GuestOS"),
            (host_os_downloads, "hostos/update/update-img.tar.zst", "host-os", "HostOS"),
            (setup_os_downloads, "setupos/disk-img.tar.zst", "setup-os", "SetupOS"),
            (
                recovery_downloads,
                "guestos/update-img-recovery/update-img.tar.zst",
                "recovery-guestos",
                "Recovery-GuestOS",
            ),
        )
        for downloads, relative_path, cdn_key, label in checks:
            if not downloads:
                continue
            # Wait for every download; result() re-raises a failed download's error.
            for finished in concurrent.futures.as_completed(downloads):
                finished.result()
            local_hash = compute_sha256(storage.dev_out / relative_path)
            cdn_hash = self.compare_cdn_hash(cdn_key, storage)
            # NOTE(review): compare_hashes presumably signals a mismatch by raising -- confirm.
            compare_hashes(local_hash, cdn_hash, label)

    # --------------------------------------------------------------------------
    # MAIN RUN
    # --------------------------------------------------------------------------
    def run(self) -> None:
        """Drive one full verification pass and log how long it took.

        Inside the storage context this resolves the proposal and git hash,
        initialises the cache, starts the proposal and CDN downloads plus the
        local build on ``self.download_executor``, runs the environment checks,
        and finally performs the three comparison steps, which wait on the
        futures handed to them.
        """
        began = time.time()

        with self.storage(self.keep_temp) as dirs:

            def cdn_downloads(requested: bool, os_kind: str) -> list[Download]:
                # Start CDN downloads only for images that were asked for.
                if requested:
                    return self.start_cdn_downloads_for_os(os_kind, dirs)
                return []

            guest_os_hash, host_os_hash = self.process_proposal()
            self.decide_git_hash()
            self.init_cache()

            proposal_downloads = self.start_proposal_download_if_needed(dirs)
            guest_os_downloads = cdn_downloads(self.verify_guestos, "guest-os")
            host_os_downloads = cdn_downloads(self.verify_hostos, "host-os")
            setup_os_downloads = cdn_downloads(self.verify_setupos, "setup-os")
            recovery_downloads = cdn_downloads(self.verify_recovery, "recovery-guestos")
            local_build = self.download_executor.submit(self.build_locally, dirs)

            # The environment is checked while the background work is in flight.
            self.check_environment()

            # Each step below receives futures and waits for them as needed.
            self.verify_proposal_artifacts(proposal_downloads)
            self.compare_proposal_vs_cdn(
                dirs,
                guest_os_hash,
                guest_os_downloads,
                host_os_hash,
                host_os_downloads,
            )
            self.compare_with_local_build(
                dirs,
                local_build,
                guest_os_downloads,
                host_os_downloads,
                setup_os_downloads,
                recovery_downloads,
            )

        hours, remainder = divmod(time.time() - began, 3600)
        minutes, seconds = divmod(remainder, 60)
        logger.info(f"Total time: {int(hours)}h {int(minutes)}m {int(seconds)}s")


# ------------------------------------------------------------------------------
# MAIN
# ------------------------------------------------------------------------------
def main() -> None:
    """Command-line entry point: configure logging, run the verifier, set the exit code.

    Exits with status 1 on ``VerificationError`` and 2 on any other
    ``RuntimeError``.  On Ctrl+C the module-level ``interrupted`` flag is set
    and the interrupt is re-raised.  The download executor is always shut
    down, cancelling futures that have not started yet.
    """
    global interrupted

    args = parse_args()
    conventional_logging(one_line_logs=True, verbose=args.debug)

    verifier = ReproducibilityVerifier(
        verify_guestos=args.guestos,
        verify_hostos=args.hostos,
        verify_setupos=args.setupos,
        verify_recovery=args.recovery,
        proposal_id=args.proposal_id,
        git_commit=args.commit,
        download_source_mode=args.download_source,
        base_cache_dir=args.cache_dir,
        clean_base_cache_dir=bool(args.clean),
        keep_temp=bool(args.debug),
    )

    try:
        verifier.run()
    except (VerificationError, RuntimeError) as failure:
        logger.critical("%s", failure)
        # Verification failures exit with 1; every other runtime failure with 2.
        sys.exit(1 if isinstance(failure, VerificationError) else 2)
    except KeyboardInterrupt:
        logger.info("Received interrupt signal.")
        interrupted = True
        raise
    finally:
        verifier.download_executor.shutdown(cancel_futures=True)


# Run the verifier only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
